Convolutional LSTM

Tutorial on how to train a convolutional neural network with a bidirectional LSTM to predict protein subcellular localization.


In [2]:
# Import all the necessary modules
import os
import sys
os.environ["THEANO_FLAGS"] = "mode=FAST_RUN,optimizer=None,device=cpu,floatX=float32"
sys.path.insert(0,'..')
import numpy as np
import theano
import theano.tensor as T
import lasagne
from confusionmatrix import ConfusionMatrix
from utils import iterate_minibatches
import matplotlib.pyplot as plt
import time
import itertools
%matplotlib inline

Building the network

The first step is to define the network architecture. Here we use an input layer, two convolutional layers (the first one applying two filter sizes in parallel), a bidirectional LSTM, a dense layer and an output layer. These are the steps we will follow:

1.- Specify the hyperparameters of the network:


In [3]:
batch_size = 128
seq_len = 400
n_feat = 20
n_hid = 15
n_class = 10
lr = 0.0025
n_filt = 10
drop_prob = 0.5

2.- Define the input variables to our network:


In [4]:
# We use ftensor3 because the protein data is a 3D tensor of float32
input_var = T.ftensor3('inputs')
# ivector because the labels are a one-dimensional vector of integers
target_var = T.ivector('targets')
# fmatrix because the masks that flag the padded positions form a 2D matrix of float32
mask_var = T.fmatrix('masks')
# Dummy data to check the shapes of the layers while building the network
X = np.random.randint(0,10,size=(batch_size,seq_len,n_feat)).astype('float32')
Xmask = np.ones((batch_size,seq_len)).astype('float32')
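
To see what such a mask encodes, here is a small illustration with made-up sequence lengths (Lasagne's convention is 1.0 for real positions and 0.0 for padded positions, which the LSTM then skips). This toy snippet is not part of the tutorial's data pipeline:

# Toy example (hypothetical lengths) of the padding-mask convention
toy_lengths = np.array([5, 3])                   # two sequences of length 5 and 3
toy_mask = np.zeros((2, 8), dtype='float32')     # both padded to length 8
for i, l in enumerate(toy_lengths):
    toy_mask[i, :l] = 1.0
print(toy_mask)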

3.- Define the layers of the network:


In [5]:
# Input layer; defines the shape of the input data
l_in = lasagne.layers.InputLayer(shape=(batch_size, None, n_feat), input_var=input_var, name='Input')
print('Input layer: {}'.format(
    lasagne.layers.get_output(l_in, inputs={l_in: input_var}).eval({input_var: X}).shape))

# Mask input layer
l_mask = lasagne.layers.InputLayer(shape=(batch_size, None), input_var=mask_var, name='Mask')
print('Mask layer: {}'.format(
    lasagne.layers.get_output(l_mask, inputs={l_mask: mask_var}).eval({mask_var: Xmask}).shape))

# Shuffle the dimensions to (batch, n_feat, seq_len) so the 1D convolution runs over the sequence axis
l_shu = lasagne.layers.DimshuffleLayer(l_in, (0,2,1))
print('DimshuffleLayer layer: {}'.format(
    lasagne.layers.get_output(l_shu, inputs={l_in: input_var}).eval({input_var: X}).shape))

# Convolutional layers with different filter sizes
l_conv_a = lasagne.layers.Conv1DLayer(l_shu, num_filters=n_filt, pad='same', stride=1, 
                                      filter_size=3, nonlinearity=lasagne.nonlinearities.rectify)
print('Convolutional layer size 3: {}'.format(
    lasagne.layers.get_output(l_conv_a, inputs={l_in: input_var}).eval({input_var: X}).shape))

l_conv_b = lasagne.layers.Conv1DLayer(l_shu, num_filters=n_filt, pad='same', stride=1, 
                                      filter_size=5, nonlinearity=lasagne.nonlinearities.rectify)
print('Convolutional layer size 5: {}'.format(
    lasagne.layers.get_output(l_conv_b, inputs={l_in: input_var}).eval({input_var: X}).shape))

# The outputs of the two convolutions are concatenated along the filter dimension
l_conc = lasagne.layers.ConcatLayer([l_conv_a, l_conv_b], axis=1)
print('Concatenated convolutional layers: {}'.format(
    lasagne.layers.get_output(l_conc, inputs={l_in: input_var}).eval({input_var: X}).shape))

# Second CNN layer
l_conv_final = lasagne.layers.Conv1DLayer(l_conc, num_filters=n_filt*2, pad='same', 
                                          stride=1, filter_size=3, 
                                          nonlinearity=lasagne.nonlinearities.rectify)
print('Final convolutional layer: {}'.format(
    lasagne.layers.get_output(l_conv_final, inputs={l_in: input_var}).eval({input_var: X}).shape))

l_reshu = lasagne.layers.DimshuffleLayer(l_conv_final, (0,2,1))
print('Second DimshuffleLayer layer: {}'.format(
    lasagne.layers.get_output(l_reshu, inputs={l_in: input_var}).eval({input_var: X}).shape))

# Bidirectional LSTM: a forward and a backward LSTM layer; we only keep the last hidden state (only_return_final)
l_fwd = lasagne.layers.LSTMLayer(l_reshu, num_units=n_hid, name='LSTMFwd', mask_input=l_mask, 
                                 only_return_final=True, nonlinearity=lasagne.nonlinearities.tanh)
l_bck = lasagne.layers.LSTMLayer(l_reshu, num_units=n_hid, name='LSTMBck', mask_input=l_mask, 
                                 only_return_final=True, backwards=True, nonlinearity=lasagne.nonlinearities.tanh)
print('Forward LSTM last hidden state: {}'.format(
    lasagne.layers.get_output(l_fwd, inputs={l_in: input_var, l_mask: mask_var}).eval(
        {input_var: X, mask_var:Xmask}).shape))
print('Backward LSTM hidden state: {}'.format(
    lasagne.layers.get_output(l_bck, inputs={l_in: input_var, l_mask: mask_var}).eval(
        {input_var: X, mask_var:Xmask}).shape))

# Concatenate both layers
l_conc_lstm = lasagne.layers.ConcatLayer([l_fwd, l_bck], axis=1)
print('Concatenated last hidden states: {}'.format(
    lasagne.layers.get_output(l_conc_lstm, inputs={l_in: input_var, l_mask: mask_var}).eval(
        {input_var: X, mask_var:Xmask}).shape))

# Dense layer with ReLU activation function
l_dense = lasagne.layers.DenseLayer(l_conc_lstm, num_units=n_hid*2, name="Dense",
                                    nonlinearity=lasagne.nonlinearities.rectify)
print('Dense layer: {}'.format(
    lasagne.layers.get_output(l_dense, inputs={l_in: input_var, l_mask: mask_var}).eval(
        {input_var: X, mask_var:Xmask}).shape))

# Output layer with a softmax activation function. Note that dropout is applied to the dense layer output before this layer
l_out = lasagne.layers.DenseLayer(lasagne.layers.dropout(l_dense, p=drop_prob), num_units=n_class, name="Softmax", 
                                  nonlinearity=lasagne.nonlinearities.softmax)
print('Output layer: {}'.format(
    lasagne.layers.get_output(l_out, inputs={l_in: input_var, l_mask: mask_var}).eval(
        {input_var: X, mask_var:Xmask}).shape))


Input layer: (128, 400, 20)
Mask layer: (128, 400)
DimshuffleLayer layer: (128, 20, 400)
Convolutional layer size 3: (128, 10, 400)
Convolutional layer size 5: (128, 10, 400)
Concatenated convolutional layers: (128, 20, 400)
Final convolutional layer: (128, 20, 400)
Second DimshuffleLayer layer: (128, 400, 20)
Forward LSTM last hidden state: (128, 15)
Backward LSTM hidden state: (128, 15)
Concatenated last hidden states: (128, 30)
Dense layer: (128, 30)
Output layer: (128, 10)

4.- Calculate the prediction and network loss for the training set and update the network weights:


In [6]:
# Get the network output for training; deterministic=False keeps dropout active during training
prediction = lasagne.layers.get_output(l_out, inputs={l_in: input_var, l_mask: mask_var}, deterministic=False)

# Calculate the categorical cross entropy between the labels and the prediction
t_loss = T.nnet.categorical_crossentropy(prediction, target_var)

# Training loss
loss = T.mean(t_loss)

# Parameters
params = lasagne.layers.get_all_params([l_out], trainable=True)

# Get the network gradients and rescale them so that their total norm does not exceed 3
all_grads = lasagne.updates.total_norm_constraint(T.grad(loss, params), 3)

# Update parameters using ADAM 
updates = lasagne.updates.adam(all_grads, params, learning_rate=lr)
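
lasagne.updates.total_norm_constraint rescales the gradients whenever their combined L2 norm exceeds the given maximum (3 here), which helps keep the LSTM updates from exploding. A toy numeric sketch of the idea in plain NumPy, outside the Theano graph and with hypothetical gradients:

# Toy illustration of global-norm gradient clipping
g = [np.array([3.0, 4.0]), np.array([0.0, 12.0])]         # hypothetical gradients
total_norm = np.sqrt(sum((gi ** 2).sum() for gi in g))    # sqrt(9 + 16 + 144) = 13
scale = min(1.0, 3.0 / total_norm)                        # same threshold of 3 as above
clipped = [gi * scale for gi in g]                        # combined norm is now exactly 3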

5.- Calculate the prediction and network loss for the validation set:


In [7]:
# Get the network output for validation; deterministic=True disables dropout and is only used for evaluation
val_prediction = lasagne.layers.get_output(l_out, inputs={l_in: input_var, l_mask: mask_var}, deterministic=True)

# Calculate the categorical cross entropy between the labels and the prediction
t_val_loss = lasagne.objectives.categorical_crossentropy(val_prediction, target_var)

# Validation loss 
val_loss = T.mean(t_val_loss)

6.- Build theano functions:


In [8]:
# Build functions
train_fn = theano.function([input_var, target_var, mask_var], [loss, prediction], updates=updates)
val_fn = theano.function([input_var, target_var, mask_var], [val_loss, val_prediction])
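
As an optional sanity check, the compiled validation function can be run on the dummy X and Xmask defined earlier together with made-up labels. This only confirms that the graph compiles and the shapes match; it is not part of the tutorial's training procedure:

# Sanity check with the dummy data from above and hypothetical random labels
dummy_targets = np.random.randint(0, n_class, size=batch_size).astype('int32')
dummy_loss, dummy_pred = val_fn(X, dummy_targets, Xmask)
print(dummy_loss, dummy_pred.shape)   # scalar loss and (batch_size, n_class) probabilities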

Load dataset

Once the network is built, the next step is to load the training and validation sets.


In [9]:
# Load the encoded protein sequences, labels and masks
train = np.load('data/reduced_train.npz')
X_train = train['X_train']
y_train = train['y_train']
mask_train = train['mask_train']
print(X_train.shape)


(2423, 400, 20)

In [10]:
validation = np.load('data/reduced_val.npz')
X_val = validation['X_val']
y_val = validation['y_val']
mask_val = validation['mask_val']
print(X_val.shape)


(635, 400, 20)
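
Before training, it can be useful to check how balanced the ten localization classes are. A quick sketch, assuming (as elsewhere in this tutorial) that y_train holds integer class indices from 0 to n_class-1:

# Quick look at the class balance of the training labels
print(np.bincount(y_train.astype(np.int64), minlength=n_class))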

Training

Once the data is ready and the network is compiled, we can start training the model. Here we define the number of epochs to run.


In [11]:
# Number of epochs
num_epochs = 100

# Lists to save loss and accuracy of each epoch
loss_training = []
loss_validation = []
acc_training = []
acc_validation = []
start_time = time.time()
min_val_loss = float("inf")

# Start training 
for epoch in range(num_epochs):
    
    # Full pass training set
    train_err = 0
    train_batches = 0
    confusion_train = ConfusionMatrix(n_class)

    # Generate minibatches and train on each one of them
    for batch in iterate_minibatches(X_train.astype(np.float32), y_train.astype(np.int32), 
                                     mask_train.astype(np.float32), batch_size, shuffle=True):
        # Inputs to the network
        inputs, targets, in_masks = batch
        # Calculate loss and prediction
        tr_err, predict = train_fn(inputs, targets, in_masks)
        train_err += tr_err
        train_batches += 1
        # Get the predicted class, the one with the highest predicted probability
        preds = np.argmax(predict, axis=-1)
        confusion_train.batch_add(targets, preds)
    
    # Average loss and accuracy
    train_loss = train_err / train_batches
    train_accuracy = confusion_train.accuracy()
    cf_train = confusion_train.ret_mat()

    val_err = 0
    val_batches = 0
    confusion_valid = ConfusionMatrix(n_class)

    # Generate minibatches and validate on each one of them, same procedure as before
    for batch in iterate_minibatches(X_val.astype(np.float32), y_val.astype(np.int32), 
                                     mask_val.astype(np.float32), batch_size, shuffle=True):
        inputs, targets, in_masks = batch
        err, predict_val = val_fn(inputs, targets, in_masks)
        val_err += err
        val_batches += 1
        preds = np.argmax(predict_val, axis=-1)
        confusion_valid.batch_add(targets, preds)

    val_loss = val_err / val_batches
    val_accuracy = confusion_valid.accuracy()
    cf_val = confusion_valid.ret_mat()
    
    loss_training.append(train_loss)
    loss_validation.append(val_loss)
    acc_training.append(train_accuracy)
    acc_validation.append(val_accuracy)
    
    # Save the model parameters at the epoch with the lowest validation loss
    if min_val_loss > val_loss:
        min_val_loss = val_loss
        np.savez('params/CNN-LSTM_params.npz', *lasagne.layers.get_all_param_values(l_out))
    
    print("Epoch {} of {} time elapsed {:.3f}s".format(epoch + 1, num_epochs, time.time() - start_time))
    print("  training loss:\t\t{:.6f}".format(train_loss))
    print("  validation loss:\t\t{:.6f}".format(val_loss))
    print("  training accuracy:\t\t{:.2f} %".format(train_accuracy * 100))
    print("  validation accuracy:\t\t{:.2f} %".format(val_accuracy * 100))


Epoch 1 of 100 time elapsed 14.403s
  training loss:		2.253443
  validation loss:		2.123589
  training accuracy:		20.15 %
  validation accuracy:		23.75 %
Epoch 2 of 100 time elapsed 27.392s
  training loss:		2.104757
  validation loss:		2.036812
  training accuracy:		22.20 %
  validation accuracy:		22.03 %
Epoch 3 of 100 time elapsed 40.553s
  training loss:		2.054652
  validation loss:		1.961752
  training accuracy:		22.12 %
  validation accuracy:		25.78 %
Epoch 4 of 100 time elapsed 52.457s
  training loss:		1.987681
  validation loss:		1.815700
  training accuracy:		25.66 %
  validation accuracy:		34.69 %
Epoch 5 of 100 time elapsed 64.223s
  training loss:		1.848420
  validation loss:		1.722101
  training accuracy:		36.18 %
  validation accuracy:		41.25 %
Epoch 6 of 100 time elapsed 75.492s
  training loss:		1.725710
  validation loss:		1.612766
  training accuracy:		39.76 %
  validation accuracy:		41.41 %
Epoch 7 of 100 time elapsed 86.504s
  training loss:		1.652419
  validation loss:		1.644918
  training accuracy:		42.76 %
  validation accuracy:		36.41 %
Epoch 8 of 100 time elapsed 97.535s
  training loss:		1.671238
  validation loss:		1.613247
  training accuracy:		41.74 %
  validation accuracy:		43.28 %
Epoch 9 of 100 time elapsed 108.251s
  training loss:		1.628313
  validation loss:		1.474864
  training accuracy:		43.59 %
  validation accuracy:		42.34 %
Epoch 10 of 100 time elapsed 119.088s
  training loss:		1.550918
  validation loss:		1.443710
  training accuracy:		44.08 %
  validation accuracy:		47.50 %
Epoch 11 of 100 time elapsed 129.247s
  training loss:		1.509930
  validation loss:		1.386292
  training accuracy:		44.20 %
  validation accuracy:		47.97 %
Epoch 12 of 100 time elapsed 138.565s
  training loss:		1.462579
  validation loss:		1.353001
  training accuracy:		47.90 %
  validation accuracy:		49.53 %
Epoch 13 of 100 time elapsed 148.277s
  training loss:		1.429569
  validation loss:		1.307363
  training accuracy:		50.29 %
  validation accuracy:		53.28 %
Epoch 14 of 100 time elapsed 157.998s
  training loss:		1.392389
  validation loss:		1.259628
  training accuracy:		51.27 %
  validation accuracy:		55.94 %
Epoch 15 of 100 time elapsed 167.648s
  training loss:		1.325219
  validation loss:		1.222510
  training accuracy:		54.52 %
  validation accuracy:		55.62 %
Epoch 16 of 100 time elapsed 176.718s
  training loss:		1.321800
  validation loss:		1.186144
  training accuracy:		55.06 %
  validation accuracy:		58.44 %
Epoch 17 of 100 time elapsed 185.557s
  training loss:		1.249154
  validation loss:		1.157022
  training accuracy:		56.78 %
  validation accuracy:		56.41 %
Epoch 18 of 100 time elapsed 194.392s
  training loss:		1.196872
  validation loss:		1.058985
  training accuracy:		57.61 %
  validation accuracy:		61.25 %
Epoch 19 of 100 time elapsed 204.054s
  training loss:		1.118790
  validation loss:		1.022309
  training accuracy:		59.00 %
  validation accuracy:		62.34 %
Epoch 20 of 100 time elapsed 213.556s
  training loss:		1.139719
  validation loss:		0.992359
  training accuracy:		58.55 %
  validation accuracy:		62.50 %
Epoch 21 of 100 time elapsed 222.606s
  training loss:		1.121306
  validation loss:		1.009627
  training accuracy:		57.81 %
  validation accuracy:		59.06 %
Epoch 22 of 100 time elapsed 231.762s
  training loss:		1.076135
  validation loss:		0.962385
  training accuracy:		59.58 %
  validation accuracy:		61.72 %
Epoch 23 of 100 time elapsed 240.868s
  training loss:		1.069825
  validation loss:		0.999819
  training accuracy:		59.58 %
  validation accuracy:		62.34 %
Epoch 24 of 100 time elapsed 250.099s
  training loss:		1.035219
  validation loss:		0.956678
  training accuracy:		60.73 %
  validation accuracy:		64.22 %
Epoch 25 of 100 time elapsed 259.391s
  training loss:		1.017588
  validation loss:		0.942147
  training accuracy:		60.94 %
  validation accuracy:		62.66 %
Epoch 26 of 100 time elapsed 268.701s
  training loss:		1.011031
  validation loss:		0.930468
  training accuracy:		62.34 %
  validation accuracy:		65.00 %
Epoch 27 of 100 time elapsed 277.730s
  training loss:		1.005397
  validation loss:		0.926695
  training accuracy:		61.72 %
  validation accuracy:		66.72 %
Epoch 28 of 100 time elapsed 287.280s
  training loss:		0.990362
  validation loss:		0.905680
  training accuracy:		64.02 %
  validation accuracy:		65.00 %
Epoch 29 of 100 time elapsed 296.466s
  training loss:		0.972391
  validation loss:		0.905090
  training accuracy:		64.88 %
  validation accuracy:		65.62 %
Epoch 30 of 100 time elapsed 305.203s
  training loss:		0.989357
  validation loss:		0.907728
  training accuracy:		62.38 %
  validation accuracy:		68.28 %
Epoch 31 of 100 time elapsed 315.198s
  training loss:		0.959532
  validation loss:		0.882039
  training accuracy:		63.86 %
  validation accuracy:		66.41 %
Epoch 32 of 100 time elapsed 324.598s
  training loss:		0.949153
  validation loss:		0.952918
  training accuracy:		66.04 %
  validation accuracy:		64.84 %
Epoch 33 of 100 time elapsed 333.558s
  training loss:		0.946238
  validation loss:		0.925756
  training accuracy:		66.74 %
  validation accuracy:		67.81 %
Epoch 34 of 100 time elapsed 342.383s
  training loss:		0.964411
  validation loss:		0.894072
  training accuracy:		65.75 %
  validation accuracy:		71.09 %
Epoch 35 of 100 time elapsed 351.971s
  training loss:		0.945703
  validation loss:		0.843612
  training accuracy:		67.15 %
  validation accuracy:		67.81 %
Epoch 36 of 100 time elapsed 360.679s
  training loss:		0.874511
  validation loss:		0.859236
  training accuracy:		67.85 %
  validation accuracy:		70.94 %
Epoch 37 of 100 time elapsed 370.239s
  training loss:		0.876223
  validation loss:		0.791363
  training accuracy:		68.96 %
  validation accuracy:		72.03 %
Epoch 38 of 100 time elapsed 378.992s
  training loss:		0.866474
  validation loss:		0.858488
  training accuracy:		70.11 %
  validation accuracy:		70.47 %
Epoch 39 of 100 time elapsed 387.882s
  training loss:		0.872713
  validation loss:		0.845158
  training accuracy:		68.38 %
  validation accuracy:		69.38 %
Epoch 40 of 100 time elapsed 397.240s
  training loss:		0.886970
  validation loss:		0.831529
  training accuracy:		66.98 %
  validation accuracy:		68.75 %
Epoch 41 of 100 time elapsed 406.172s
  training loss:		0.899728
  validation loss:		0.920589
  training accuracy:		69.24 %
  validation accuracy:		65.31 %
Epoch 42 of 100 time elapsed 414.459s
  training loss:		0.877083
  validation loss:		0.878080
  training accuracy:		67.56 %
  validation accuracy:		69.22 %
Epoch 43 of 100 time elapsed 422.830s
  training loss:		0.887182
  validation loss:		0.809595
  training accuracy:		68.30 %
  validation accuracy:		69.84 %
Epoch 44 of 100 time elapsed 431.170s
  training loss:		0.808207
  validation loss:		0.810061
  training accuracy:		70.72 %
  validation accuracy:		70.16 %
Epoch 45 of 100 time elapsed 440.196s
  training loss:		0.804900
  validation loss:		0.791546
  training accuracy:		71.79 %
  validation accuracy:		70.00 %
Epoch 46 of 100 time elapsed 448.686s
  training loss:		0.791007
  validation loss:		0.794164
  training accuracy:		70.85 %
  validation accuracy:		71.56 %
Epoch 47 of 100 time elapsed 457.351s
  training loss:		0.824747
  validation loss:		0.771211
  training accuracy:		71.83 %
  validation accuracy:		73.75 %
Epoch 48 of 100 time elapsed 466.175s
  training loss:		0.810002
  validation loss:		0.763675
  training accuracy:		70.27 %
  validation accuracy:		73.59 %
Epoch 49 of 100 time elapsed 474.429s
  training loss:		0.773323
  validation loss:		0.811349
  training accuracy:		72.49 %
  validation accuracy:		72.81 %
Epoch 50 of 100 time elapsed 482.962s
  training loss:		0.811062
  validation loss:		0.784490
  training accuracy:		71.22 %
  validation accuracy:		72.34 %
Epoch 51 of 100 time elapsed 492.591s
  training loss:		0.784745
  validation loss:		0.852061
  training accuracy:		71.01 %
  validation accuracy:		69.69 %
Epoch 52 of 100 time elapsed 502.851s
  training loss:		0.802265
  validation loss:		0.773009
  training accuracy:		71.92 %
  validation accuracy:		72.97 %
Epoch 53 of 100 time elapsed 512.641s
  training loss:		0.896626
  validation loss:		0.799414
  training accuracy:		69.70 %
  validation accuracy:		70.62 %
Epoch 54 of 100 time elapsed 521.844s
  training loss:		0.795625
  validation loss:		0.805085
  training accuracy:		70.93 %
  validation accuracy:		71.25 %
Epoch 55 of 100 time elapsed 531.737s
  training loss:		0.776817
  validation loss:		0.822015
  training accuracy:		73.68 %
  validation accuracy:		74.53 %
Epoch 56 of 100 time elapsed 541.875s
  training loss:		0.745365
  validation loss:		0.843081
  training accuracy:		74.01 %
  validation accuracy:		74.38 %
Epoch 57 of 100 time elapsed 552.432s
  training loss:		0.733586
  validation loss:		0.756946
  training accuracy:		74.42 %
  validation accuracy:		74.69 %
Epoch 58 of 100 time elapsed 562.113s
  training loss:		0.743256
  validation loss:		0.764647
  training accuracy:		74.47 %
  validation accuracy:		72.97 %
Epoch 59 of 100 time elapsed 571.723s
  training loss:		0.756479
  validation loss:		0.743586
  training accuracy:		73.27 %
  validation accuracy:		75.31 %
Epoch 60 of 100 time elapsed 581.231s
  training loss:		0.747494
  validation loss:		0.806278
  training accuracy:		73.56 %
  validation accuracy:		73.28 %
Epoch 61 of 100 time elapsed 590.296s
  training loss:		0.769447
  validation loss:		0.836815
  training accuracy:		72.78 %
  validation accuracy:		73.59 %
Epoch 62 of 100 time elapsed 599.067s
  training loss:		0.733004
  validation loss:		0.763363
  training accuracy:		74.51 %
  validation accuracy:		75.62 %
Epoch 63 of 100 time elapsed 607.617s
  training loss:		0.703957
  validation loss:		0.723236
  training accuracy:		75.12 %
  validation accuracy:		75.78 %
Epoch 64 of 100 time elapsed 615.903s
  training loss:		0.696825
  validation loss:		0.760943
  training accuracy:		76.03 %
  validation accuracy:		75.94 %
Epoch 65 of 100 time elapsed 624.211s
  training loss:		0.712687
  validation loss:		0.739835
  training accuracy:		74.96 %
  validation accuracy:		74.53 %
Epoch 66 of 100 time elapsed 632.782s
  training loss:		0.679262
  validation loss:		0.715170
  training accuracy:		76.03 %
  validation accuracy:		75.62 %
Epoch 67 of 100 time elapsed 641.098s
  training loss:		0.653499
  validation loss:		0.734256
  training accuracy:		76.81 %
  validation accuracy:		75.94 %
Epoch 68 of 100 time elapsed 649.814s
  training loss:		0.636183
  validation loss:		0.713825
  training accuracy:		77.96 %
  validation accuracy:		77.19 %
Epoch 69 of 100 time elapsed 658.158s
  training loss:		0.675199
  validation loss:		0.763991
  training accuracy:		75.21 %
  validation accuracy:		73.44 %
Epoch 70 of 100 time elapsed 666.487s
  training loss:		0.686597
  validation loss:		0.762452
  training accuracy:		75.12 %
  validation accuracy:		74.84 %
Epoch 71 of 100 time elapsed 675.134s
  training loss:		0.707004
  validation loss:		0.754198
  training accuracy:		75.58 %
  validation accuracy:		75.62 %
Epoch 72 of 100 time elapsed 683.845s
  training loss:		0.682106
  validation loss:		0.715462
  training accuracy:		75.95 %
  validation accuracy:		74.69 %
Epoch 73 of 100 time elapsed 692.215s
  training loss:		0.651657
  validation loss:		0.773966
  training accuracy:		77.67 %
  validation accuracy:		75.16 %
Epoch 74 of 100 time elapsed 701.002s
  training loss:		0.624513
  validation loss:		0.714460
  training accuracy:		77.38 %
  validation accuracy:		76.72 %
Epoch 75 of 100 time elapsed 710.103s
  training loss:		0.601271
  validation loss:		0.752787
  training accuracy:		80.18 %
  validation accuracy:		75.62 %
Epoch 76 of 100 time elapsed 719.657s
  training loss:		0.595531
  validation loss:		0.703477
  training accuracy:		78.82 %
  validation accuracy:		77.19 %
Epoch 77 of 100 time elapsed 728.815s
  training loss:		0.598131
  validation loss:		0.693776
  training accuracy:		80.30 %
  validation accuracy:		77.81 %
Epoch 78 of 100 time elapsed 737.810s
  training loss:		0.600390
  validation loss:		0.748458
  training accuracy:		78.37 %
  validation accuracy:		76.41 %
Epoch 79 of 100 time elapsed 746.703s
  training loss:		0.582956
  validation loss:		0.729209
  training accuracy:		79.19 %
  validation accuracy:		77.03 %
Epoch 80 of 100 time elapsed 755.897s
  training loss:		0.566219
  validation loss:		0.736072
  training accuracy:		79.93 %
  validation accuracy:		77.97 %
Epoch 81 of 100 time elapsed 765.047s
  training loss:		0.587196
  validation loss:		0.755258
  training accuracy:		79.89 %
  validation accuracy:		76.56 %
Epoch 82 of 100 time elapsed 774.162s
  training loss:		0.615505
  validation loss:		0.716928
  training accuracy:		78.87 %
  validation accuracy:		78.91 %
Epoch 83 of 100 time elapsed 783.088s
  training loss:		0.569956
  validation loss:		0.733303
  training accuracy:		79.81 %
  validation accuracy:		77.66 %
Epoch 84 of 100 time elapsed 792.493s
  training loss:		0.579606
  validation loss:		0.728263
  training accuracy:		80.72 %
  validation accuracy:		78.28 %
Epoch 85 of 100 time elapsed 801.732s
  training loss:		0.572313
  validation loss:		0.697558
  training accuracy:		80.55 %
  validation accuracy:		78.59 %
Epoch 86 of 100 time elapsed 810.887s
  training loss:		0.587405
  validation loss:		0.758074
  training accuracy:		80.76 %
  validation accuracy:		77.50 %
Epoch 87 of 100 time elapsed 819.758s
  training loss:		0.603486
  validation loss:		0.715273
  training accuracy:		79.40 %
  validation accuracy:		78.44 %
Epoch 88 of 100 time elapsed 829.092s
  training loss:		0.553078
  validation loss:		0.682315
  training accuracy:		81.25 %
  validation accuracy:		79.22 %
Epoch 89 of 100 time elapsed 837.659s
  training loss:		0.559393
  validation loss:		0.710943
  training accuracy:		81.46 %
  validation accuracy:		79.22 %
Epoch 90 of 100 time elapsed 847.592s
  training loss:		0.574693
  validation loss:		0.682275
  training accuracy:		80.35 %
  validation accuracy:		80.31 %
Epoch 91 of 100 time elapsed 856.709s
  training loss:		0.535498
  validation loss:		0.686187
  training accuracy:		81.70 %
  validation accuracy:		78.91 %
Epoch 92 of 100 time elapsed 865.102s
  training loss:		0.523886
  validation loss:		0.709255
  training accuracy:		82.07 %
  validation accuracy:		79.06 %
Epoch 93 of 100 time elapsed 873.432s
  training loss:		0.504924
  validation loss:		0.711597
  training accuracy:		82.52 %
  validation accuracy:		78.28 %
Epoch 94 of 100 time elapsed 881.772s
  training loss:		0.499599
  validation loss:		0.724033
  training accuracy:		83.39 %
  validation accuracy:		79.69 %
Epoch 95 of 100 time elapsed 890.391s
  training loss:		0.555530
  validation loss:		0.682140
  training accuracy:		80.88 %
  validation accuracy:		80.31 %
Epoch 96 of 100 time elapsed 899.381s
  training loss:		0.506254
  validation loss:		0.754861
  training accuracy:		83.18 %
  validation accuracy:		78.91 %
Epoch 97 of 100 time elapsed 908.532s
  training loss:		0.493507
  validation loss:		0.673616
  training accuracy:		83.10 %
  validation accuracy:		81.41 %
Epoch 98 of 100 time elapsed 916.907s
  training loss:		0.497518
  validation loss:		0.691903
  training accuracy:		82.65 %
  validation accuracy:		81.09 %
Epoch 99 of 100 time elapsed 925.364s
  training loss:		0.514264
  validation loss:		0.727489
  training accuracy:		83.26 %
  validation accuracy:		79.84 %
Epoch 100 of 100 time elapsed 933.698s
  training loss:		0.489799
  validation loss:		0.683398
  training accuracy:		83.59 %
  validation accuracy:		80.31 %

In [12]:
print("Minimum validation loss: {:.6f}".format(min_val_loss))


Minimum validation loss: 0.673616
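
The parameters saved at the epoch with the lowest validation loss can be restored before any further evaluation. A minimal sketch, assuming the file written during training above (np.savez stores the arrays as arr_0, arr_1, ...):

# Restore the best parameters saved during training
f = np.load('params/CNN-LSTM_params.npz')
best_params = [f['arr_%d' % i] for i in range(len(f.files))]
lasagne.layers.set_all_param_values(l_out, best_params)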

Model loss and accuracy

Here we plot the loss and the accuracy for the training and validation sets at each epoch.


In [13]:
# Plot training and validation loss
x_axis = range(num_epochs)
plt.figure(figsize=(8,6))
plt.plot(x_axis,loss_training)
plt.plot(x_axis,loss_validation)
plt.xlabel('Epoch')
plt.ylabel('Error')
plt.legend(('Training','Validation'));



In [14]:
plt.figure(figsize=(8,6))
plt.plot(x_axis,acc_training)
plt.plot(x_axis,acc_validation)
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(('Training','Validation'));


Confusion matrix

The confusion matrix lets us visualize how well each class is predicted and which misclassifications are most common.


In [15]:
# Plot confusion matrix 
# Code based on http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

plt.figure(figsize=(8,8))
cmap=plt.cm.Blues   
plt.imshow(cf_val, interpolation='nearest', cmap=cmap)
plt.title('Confusion matrix validation set')
plt.colorbar()
tick_marks = np.arange(n_class)
classes = ['Nucleus','Cytoplasm','Extracellular','Mitochondrion','Cell membrane','ER',
           'Chloroplast','Golgi apparatus','Lysosome','Vacuole']

plt.xticks(tick_marks, classes, rotation=60)
plt.yticks(tick_marks, classes)

thresh = cf_val.max() / 2.
for i, j in itertools.product(range(cf_val.shape[0]), range(cf_val.shape[1])):
    plt.text(j, i, cf_val[i, j],
             horizontalalignment="center",
             color="white" if cf_val[i, j] > thresh else "black")

plt.tight_layout()
plt.ylabel('True location')
plt.xlabel('Predicted location');
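
Since the rows of cf_val are the true locations and the columns the predictions (matching the axis labels above), the per-class recall can be read off the diagonal. A short sketch, assuming ret_mat() returns the raw count matrix:

# Per-class recall: correctly predicted proteins over all proteins of that class
row_sums = cf_val.sum(axis=1).astype(np.float64)
recall = np.diag(cf_val) / np.maximum(row_sums, 1)
for name, r in zip(classes, recall):
    print('{:<16s} {:.2f}'.format(name, r))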


